#include "asan_shadow_memory.h"
#include "asan_mapping.h"
#include "asan_thread.h"
#include "asan_utils.h"
#include "sanitizer_libc.h"
#include "sanitizer_linux_syscall.h"
#include "asan_interceptors.h"
#include "stdio.h"
#include "asan_internal.h"
#include <sys/errno.h>

namespace __sanitizer {

// Sentinel shadow value (-1). PoisonUnitHeadBytes and PoisonUnitTailBytes
// compare a unit's old value against it and, on a match, replace the unit with
// the new poison value even when only a small part of the unit is affected.
static su_t UNINITIALIZED_SHADOW = -1;

// Maps `size` bytes of readable/writable memory at `fixed_addr` with MAP_FIXED.
//
// Before mapping, a scratch file is created under /dev/shm, resized to `size`
// and unlinked. Its path is "/dev/shm/" + `name` followed by "-<pid>" and
// "-<tid>", where the numeric parts are whatever internal_lltoa(value, buf, 16)
// writes.
//
// Parameters:
//   fixed_addr        requested address; rounded down to ASAN_PAGE_SIZE here.
//   size              requested length; rounded up to ASAN_PAGE_SIZE here.
//   additional_flags  OR-ed into the mmap flags.
//   name              label appended to the scratch-file path.
//
// Returns the descriptor obtained for the scratch file; this function does not
// close it. Every failing step aborts through ASSERT.
//
// NOTE(review): MAP_ANON is passed together with the descriptor. On Linux an
// anonymous mapping ignores the fd, so the scratch file may not actually back
// the mapping — confirm which behaviour is intended.
u64 MmapFixed(uptr fixed_addr, uptr size, int additional_flags, const char* name) {
  // Normalise the request to whole pages: the size grows, the address shrinks.
  size = RoundUpTo(size, ASAN_PAGE_SIZE);
  fixed_addr = RoundDownTo(fixed_addr, ASAN_PAGE_SIZE);

  // Build the scratch-file path piece by piece with the internal string helpers.
  static const int buff_size = 256;
  char filepath[buff_size] = "/dev/shm/";
  Tid tid = GetCurrentTid();
  Tid pid = GetCurrentPid();
  long long ret = 0;
  // ret = REAL(snprintf(filepath, buff_size, "/dev/shm/%s-%u-%u.asan.sd", name, pid, tid));
  // ASSERT((ret > 0 && ret < buff_size), "Failed to create the shadow memory file. 1\n");
  internal_strncat(filepath, name, buff_size - 70);
  char pidbuff[30] = "-";
  internal_lltoa(pid, pidbuff + 1, 16);
  internal_strncat(filepath, pidbuff, 30);
  internal_lltoa(tid, pidbuff + 1, 16);
  internal_strncat(filepath, pidbuff, 30);
  if (asan_interceptor_inited)
    VReport(ASAN_LOG_DEBUG, "Mmap %s %llx %llx\n", filepath, fixed_addr, size);

  // Keep the open result signed. Stored in an unsigned u64, `fd >= 0` is always
  // true and a failed open would slip past the assertion below.
  const long long fd = static_cast<long long>(
      internal_openat(0, filepath, O_RDWR | O_CREAT | O_TRUNC | O_CLOEXEC, S_IRWXU));
  ASSERT((fd >= 0), "Failed to create the shadow memory file. 2\n");
  ret = internal_ftruncate(fd, size);
  ASSERT((ret == 0), "Failed to create the shadow memory file. 3\n");
  // The path is removed right away; only the open descriptor keeps the file alive.
  ret = internal_unlink(filepath);
  ASSERT((ret == 0), "Failed to create the shadow memory file. 4\n");
  ret = internal_mmap((void *)fixed_addr, size, PROT_WRITE | PROT_READ,
                      MAP_PRIVATE | MAP_FIXED | MAP_ANON | additional_flags, fd, 0);
  // NOTE(review): assumes internal_mmap reports failure as -1 and sets errno —
  // confirm against its implementation.
  if (ret == -1) {
    VReport(ASAN_LOG_ERR, "Failed to map shadow memory file to virtual memory. errno=%i\n", errno);
    ASSERT(false, "");
  }
  return static_cast<u64>(fd);
}

int AsanShadowMem::InitShadow() {
  scale = ASAN_SHADOW_SCALE;
  scale_factor = 1 << scale;
  half_scale_factor = 1 << (scale - 1);
  // Initialized in __asan_init
  // kHighMemEnd = 0x7fffffffffffULL;
  // kMidMemBeg = 0x3000000000ULL;
  // kMidMemEnd = 0x4fffffffffULL;

  // mmap to file
  // fd = MmapFixed(beg, size, 0);
  // low_mmap_fd = MmapFixed(kLowShadowBeg, kLowShadowEnd - kLowShadowBeg + 1, 0, "low_shadow");
  // high_mmap_fd = MmapFixed(kHighShadowBeg, kHighShadowEnd - kHighShadowBeg + 1, 0, "high_shadow");

  // A crucial work
  uptr mmap_size = kLowShadowEnd - kLowShadowBeg + 1;
  uptr start_addr = kLowShadowBeg;
  char buff[30] = {0};
  while (start_addr < kHighShadowEnd) {
    internal_lltoa(start_addr, buff, 16);
    MmapFixed(start_addr, mmap_size, 0, buff);
    start_addr += mmap_size;
  }
  size = kHighShadowEnd - kLowShadowBeg;
  ASAN_SHADOW_BEG = kLowShadowBeg - ASAN_PAGE_SIZE;
  ASAN_SHADOW_END = kHighShadowEnd + ASAN_PAGE_SIZE;
  beg = ASAN_SHADOW_BEG;
  end = ASAN_SHADOW_END;
  // Too large size, we do not initialize it at the beginning
  // internal_memset((void *)kLowShadowBeg, UNINITIALIZED_SHADOW, size);
  return 0;
}

// Placeholder teardown for the shadow region; nothing is released yet.
// Returns 0, matching the other int-returning methods of this class.
int AsanShadowMem::DeleteShadow() {
  // Flowing off the end of a non-void function is undefined behaviour, so an
  // explicit value must be returned even though there is no work to do.
  return 0;
}

// Computes the shadow value of one unit after its first `affected_bytes` bytes
// are poisoned with `new_value`.
//
//   old_value       current shadow value of the unit.
//   new_value       poison value to apply.
//   affected_bytes  number of leading bytes poisoned, in [1, scale_factor].
//
// Returns `new_value` when the unit ends up poisoned, otherwise `old_value`
// unchanged. To avoid false positives a partially covered unit is left alone
// whenever some of its bytes may still be usable.
su_t AsanShadowMem::PoisonUnitHeadBytes(su_t old_value, su_t new_value, int affected_bytes) {
  CHECK((affected_bytes <= scale_factor && affected_bytes > 0));

  // The whole unit is covered.
  if (affected_bytes == scale_factor) return new_value;

  // Fully addressable unit: bytes past the poisoned head remain usable.
  if (old_value == 0) return old_value;

  // Partially addressable unit: keep it only if valid bytes survive beyond the
  // poisoned head.
  if (old_value > 0) return (old_value > affected_bytes) ? old_value : new_value;

  // Negative (already poisoned) unit: overwrite when more than half of it is
  // affected, or when it still holds the uninitialised sentinel.
  if (affected_bytes > half_scale_factor) return new_value;
  if (old_value == UNINITIALIZED_SHADOW) return new_value;
  return old_value;
}

// Computes the shadow value of one unit after its last `affected_bytes` bytes
// are poisoned with `new_value`.
//
//   old_value       current shadow value of the unit.
//   new_value       poison value to apply.
//   affected_bytes  number of trailing bytes poisoned, in [1, scale_factor].
//
// Returns `new_value`, the count of leading bytes that stay usable
// (scale_factor - affected_bytes), or `old_value` unchanged.
su_t AsanShadowMem::PoisonUnitTailBytes(su_t old_value, su_t new_value, int affected_bytes) {
  CHECK((affected_bytes <= scale_factor && affected_bytes > 0));

  // The range starts on the unit boundary, so the whole unit is covered.
  if (affected_bytes == scale_factor) return new_value;

  // Fully addressable unit: only the bytes in front of the poisoned tail remain.
  if (old_value == 0) return scale_factor - affected_bytes;

  if (old_value > 0) {
    // Shrink the usable prefix only if the poisoned tail reaches into it.
    const bool tail_reaches_prefix = old_value + affected_bytes > scale_factor;
    if (tail_reaches_prefix) return scale_factor - affected_bytes;
    return old_value;
  }

  // Negative (already poisoned) unit: overwrite when more than half of it is
  // affected, or when it still holds the uninitialised sentinel.
  if (affected_bytes > half_scale_factor) return new_value;
  if (old_value == UNINITIALIZED_SHADOW) return new_value;
  return old_value;
}

// Computes the shadow value of one unit after bytes [beg_byte, end_byte] inside
// it are poisoned with `new_value`. Used when the poisoned range lies within a
// single unit.
//
//   old_value  current shadow value of the unit.
//   new_value  poison value to apply.
//   beg_byte   first poisoned byte, 0 <= beg_byte <= end_byte.
//   end_byte   last poisoned byte, end_byte < scale_factor.
//
// A positive shadow value k means bytes [0, k-1] are usable. The result is
// `new_value`, the length of the usable prefix that remains, or `old_value`
// unchanged. A unit that is already negative (poisoned) is never modified here.
su_t AsanShadowMem::PoisonUnit(su_t old_value, su_t new_value, int beg_byte, int end_byte) {
  CHECK(beg_byte >= 0);
  CHECK(end_byte < scale_factor);
  CHECK(beg_byte <= end_byte);
  su_t ret_value = old_value;
  if (old_value == 0) {
    if (end_byte < scale_factor - 1) {
      // Bytes after end_byte stay usable; a prefix-style shadow value cannot
      // express a hole, so the unit is left fully addressable.
      ret_value = 0;
    } else {
      // The poisoned range runs to the end of the unit, leaving [0, beg_byte-1]
      // usable. The shadow value is therefore beg_byte (the prefix length), the
      // same encoding the old_value > 0 branch below produces.
      if (beg_byte == 0) ret_value = new_value;
      else ret_value = beg_byte;
    }
  } else if (old_value > 0) {
    // Intersect the usable prefix [0, old_value-1] with [beg_byte, end_byte].
    if (beg_byte > old_value - 1) {
      // The range lies entirely in the already unusable part: nothing to do.
    } else if (end_byte >= old_value - 1) {
      // The range covers the prefix from beg_byte to its end: cut it at beg_byte.
      if (beg_byte > 0)
        ret_value = beg_byte;
      else
        ret_value = new_value;
    } else {
      // Usable bytes remain after end_byte; leave the unit untouched.
    }
  }
  return ret_value;
}

// Poisons the `size` bytes starting at `membeg` with shadow value `v`.
//
// When the range spans several shadow units, every unit is first set to `v`;
// the first and last units, which may be only partly covered, are then given
// the values computed by PoisonUnitTailBytes / PoisonUnitHeadBytes from their
// previous contents. A range inside a single unit is handled by PoisonUnit.
// An empty range is a no-op. Always returns 0.
int AsanShadowMem::PoisonMem(const void *membeg, uptr size, su_t v) {
  // Nothing to poison. Without this guard `tmpbeg + size - 1` falls below
  // `tmpbeg`, which the unit arithmetic below does not handle.
  if (size == 0) return 0;

  uptr tmpbeg = (uptr)membeg;
  uptr tmpend = tmpbeg + size - 1;  // last byte of the range (inclusive)
  uptr shadowbeg = MEM_TO_SHADOW(tmpbeg);
  // The range need not be aligned to scale_factor: find the enclosing unit
  // boundaries and the byte offsets of the range inside its first/last unit.
  uptr alignedbeg = RoundDownTo(tmpbeg, scale_factor);
  uptr beg_byte = tmpbeg - alignedbeg;

  // alignedend is the last byte of the unit containing tmpend. When tmpend is
  // itself unit-aligned, RoundUpTo leaves it unchanged, so step one unit ahead.
  uptr alignedend = RoundUpTo(tmpend, scale_factor) - 1;
  if (alignedend < tmpend) alignedend += scale_factor;
  uptr end_byte = tmpend - (alignedend + 1 - scale_factor);

  uptr shadowSize = (alignedend - alignedbeg) / scale_factor + 1;  // units touched
  uptr shadowend = shadowbeg + shadowSize - 1;

  // TODO: the following code should be thread safe
  if (shadowSize > 1) {
    // Work out the edge units' new values before the bulk fill overwrites them.
    su_t firstByteValue = *(su_t *)(shadowbeg);
    firstByteValue = PoisonUnitTailBytes(firstByteValue, v, scale_factor - beg_byte);

    su_t lastByteValue = *(su_t *)(shadowend);
    lastByteValue = PoisonUnitHeadBytes(lastByteValue, v, end_byte + 1);
    internal_memset((void *)shadowbeg, v, shadowSize);
    // end TODO
    // Put back the edge values wherever they differ from the bulk fill.
    if (firstByteValue != v) {
      *((su_t *)shadowbeg) = firstByteValue;
    }
    if (lastByteValue != v) {
      *((su_t *)shadowend) = lastByteValue;
    }
  } else {
    su_t firstByteValue = *(su_t *)(shadowbeg);
    *((su_t *)shadowbeg) = PoisonUnit(firstByteValue, v, beg_byte, end_byte);
  }
  return 0;
}

// Computes the shadow value of one unit after its first `affected_bytes` bytes
// are unpoisoned.
//
//   old_value       current shadow value of the unit.
//   affected_bytes  number of leading bytes unpoisoned, in [1, scale_factor].
//
// Returns 0 when the unit is (or becomes) fully addressable, otherwise the
// length of the usable prefix after the operation.
su_t AsanShadowMem::UnpoisonUnitHeadBytes(su_t old_value, int affected_bytes) {
  CHECK((affected_bytes <= scale_factor && affected_bytes > 0));

  // Already fully addressable, or the whole unit is being unpoisoned.
  if (old_value == 0 || affected_bytes == scale_factor) return 0;

  // No usable prefix existed before: the unpoisoned head becomes the prefix.
  // affected_bytes is strictly below scale_factor at this point.
  if (old_value < 0 || old_value >= scale_factor) return affected_bytes;

  // A usable prefix already exists: keep whichever prefix is longer.
  return Max<su_t>(affected_bytes, old_value);
}

// Shadow value of a unit after its last `affected_bytes` bytes are unpoisoned.
// Always returns 0; `old_value` is not consulted.
// Unlike the head-byte helper, the check requires affected_bytes to be strictly
// less than scale_factor (and greater than 0).
// The only reference to this helper visible in this file is commented out (in
// UnpoisonMem).
su_t AsanShadowMem::UnpoisonUnitTailBytes(su_t old_value, int affected_bytes) {
  CHECK((affected_bytes < scale_factor && affected_bytes > 0));
  return 0;
}

// Computes the shadow value of one unit after bytes [beg_byte, end_byte] inside
// it are unpoisoned. Used when the unpoisoned range lies within a single unit.
//
//   old_value  current shadow value of the unit.
//   beg_byte   first unpoisoned byte, 0 <= beg_byte <= end_byte.
//   end_byte   last unpoisoned byte, end_byte < scale_factor.
//
// A shadow value can only describe a usable prefix, so the result always marks
// [0, end_byte] as usable regardless of beg_byte. Returns 0 when the whole unit
// becomes usable, otherwise the prefix length, or `old_value` when the existing
// prefix already covers the range.
su_t AsanShadowMem::UnpoisonUnit(su_t old_value, int beg_byte, int end_byte) {
  CHECK(beg_byte >= 0);
  CHECK(end_byte < scale_factor);
  CHECK(beg_byte <= end_byte);

  // Already fully addressable (e.g. the same memory is unpoisoned again).
  if (old_value == 0) return old_value;

  // With a valid prefix [0, old_value-1] that already reaches end_byte there is
  // nothing to extend.
  const bool fully_poisoned = old_value < 0 || old_value >= scale_factor;
  if (!fully_poisoned && end_byte <= old_value - 1) return old_value;

  // [0, end_byte] becomes usable. A prefix spanning the whole unit must be
  // stored as 0: the value scale_factor itself is read back as "poisoned" by
  // IsPoisonedUnit.
  const int usable_prefix = end_byte + 1;
  if (usable_prefix == scale_factor) return 0;
  return usable_prefix;
}

// Unpoisons the `size` bytes starting at `membeg`.
//
// When the range spans several shadow units, every unit is cleared to 0 except
// the last one, which receives the value computed by UnpoisonUnitHeadBytes if
// the range stops short of its end. The first unit is always cleared, even when
// the range starts in its middle. A range inside a single unit is handled by
// UnpoisonUnit. An empty range is a no-op. Always returns 0.
int AsanShadowMem::UnpoisonMem(const void *membeg, uptr size) {
  // Nothing to unpoison. Without this guard `tmpbeg + size - 1` falls below
  // `tmpbeg`, which the unit arithmetic below does not handle.
  if (size == 0) return 0;

  uptr tmpbeg = (uptr)membeg;
  uptr tmpend = tmpbeg + size - 1;  // last byte of the range (inclusive)
  uptr shadowbeg = MEM_TO_SHADOW(tmpbeg);
  // The range need not be aligned to scale_factor: find the enclosing unit
  // boundaries and the byte offsets of the range inside its first/last unit.
  uptr alignedbeg = RoundDownTo(tmpbeg, scale_factor);
  uptr beg_byte = tmpbeg - alignedbeg;

  // alignedend is the last byte of the unit containing tmpend. When tmpend is
  // itself unit-aligned, RoundUpTo leaves it unchanged, so step one unit ahead.
  uptr alignedend = RoundUpTo(tmpend, scale_factor) - 1;
  if (alignedend < tmpend) alignedend += scale_factor;
  uptr end_byte = tmpend - (alignedend + 1 - scale_factor);

  uptr shadowSize = (alignedend - alignedbeg) / scale_factor + 1;  // units touched
  uptr shadowend = shadowbeg + shadowSize - 1;

  if (shadowSize > 1) {
    // Compute the last unit's value before the bulk clear overwrites it. It
    // stays 0 when the range extends to the very end of that unit.
    su_t lastByteValue = 0;
    if (end_byte != scale_factor - 1) {
      lastByteValue = *(su_t *)(shadowend);
      lastByteValue = UnpoisonUnitHeadBytes(lastByteValue, end_byte + 1);
    }
    internal_memset((void *)shadowbeg, 0, shadowSize);
    if (lastByteValue != 0) {
      *((su_t *)shadowend) = lastByteValue;
    }
  } else {
    su_t firstByteValue = *(su_t *)(shadowbeg);
    *((su_t *)shadowbeg) = UnpoisonUnit(firstByteValue, beg_byte, end_byte);
  }
  return 0;
}

// Reports whether an access to bytes [beg_byte, end_byte] of a unit with shadow
// value `old_value` touches a poisoned byte.
//
// A value of 0 means the whole unit is usable; a negative value or one above
// scale_factor - 1 means none of it is; any other value k means bytes [0, k-1]
// are usable. `beg_byte` is only range-checked.
bool AsanShadowMem::IsPoisonedUnit(su_t old_value, int beg_byte, int end_byte) {
  CHECK(beg_byte >= 0);
  CHECK(end_byte < scale_factor);
  CHECK(beg_byte <= end_byte);

  if (old_value == 0) return false;

  const bool nothing_usable = old_value < 0 || old_value > scale_factor - 1;
  // Otherwise the access is bad exactly when it runs past the usable prefix.
  return nothing_usable || end_byte > old_value - 1;
}

// Reports whether any byte of the `size`-byte range starting at `membeg` is
// poisoned according to the shadow. An empty range is never poisoned.
//
// The first and last shadow units are checked with IsPoisonedUnit using the
// byte offsets of the range inside them; every unit strictly between them must
// be 0 (fully addressable).
bool AsanShadowMem::IsPoisonedMem(const void *membeg, uptr size) {
  if (size == 0) return false;
  uptr tmpbeg = (uptr)membeg;
  uptr tmpend = tmpbeg + size - 1;  // last byte of the range (inclusive)
  uptr shadowbeg = MEM_TO_SHADOW(tmpbeg);
  // The range need not be aligned to scale_factor: find the enclosing unit
  // boundaries and the byte offsets of the range inside its first/last unit.
  uptr alignedbeg = RoundDownTo(tmpbeg, scale_factor);
  uptr beg_byte = tmpbeg - alignedbeg;

  // alignedend is the last byte of the unit containing tmpend. When tmpend is
  // itself unit-aligned, RoundUpTo leaves it unchanged, so step one unit ahead.
  uptr alignedend = RoundUpTo(tmpend, scale_factor) - 1;
  if (alignedend < tmpend) alignedend += scale_factor;
  uptr end_byte = tmpend - (alignedend + 1 - scale_factor);

  uptr shadowSize = (alignedend - alignedbeg) / scale_factor + 1;  // units touched
  uptr shadowend = shadowbeg + shadowSize - 1;

  if (shadowSize > 1) {
    su_t firstByteValue = *(su_t *)(shadowbeg);
    if (IsPoisonedUnit(firstByteValue, beg_byte, scale_factor - 1)) return true;
    su_t lastByteValue = *(su_t *)(shadowend);
    if (IsPoisonedUnit(lastByteValue, 0, end_byte)) return true;
    // Every interior unit, i.e. [shadowbeg + 1, shadowend - 1] inclusive, must
    // be 0. The bound is `i < shadowend` so that the unit just before the last
    // one is examined as well.
    for (uptr i = shadowbeg + 1; i < shadowend; ++i) {
      if (*((su_t *)i) != 0) return true;
    }
    return false;
  } else {
    su_t firstByteValue = *(su_t *)(shadowbeg);
    return IsPoisonedUnit(firstByteValue, beg_byte, end_byte);
  }
}

bool AsanShadowMem::IsFreeAfterUse(const void* membeg) {
  uptr shadowbeg = MEM_TO_SHADOW((uptr)membeg);
  return kAsanHeapFreeMagic == *((su_t*)shadowbeg);
}

// Looks for a poisoned byte in the `size`-byte range starting at `membeg`.
//
// Returns 0 when the range is empty or no poisoned byte is found. Otherwise the
// returned address depends on where the hit occurs:
//   - first shadow unit: the start of the range (`membeg`);
//   - last shadow unit:  the last byte of the range;
//   - interior unit:     SHADOW_TO_MEM of that shadow byte.
// The units are examined in that order, so the result is not necessarily the
// lowest poisoned address in the range.
uptr AsanShadowMem::GetPoisonedAddr(const void *membeg, uptr size) {
  if (size == 0) return 0;
  uptr tmpbeg = (uptr)membeg;
  uptr tmpend = tmpbeg + size - 1;  // last byte of the range (inclusive)
  uptr shadowbeg = MEM_TO_SHADOW(tmpbeg);
  // The range need not be aligned to scale_factor: find the enclosing unit
  // boundaries and the byte offsets of the range inside its first/last unit.
  uptr alignedbeg = RoundDownTo(tmpbeg, scale_factor);
  uptr beg_byte = tmpbeg - alignedbeg;

  // alignedend is the last byte of the unit containing tmpend. When tmpend is
  // itself unit-aligned, RoundUpTo leaves it unchanged, so step one unit ahead.
  uptr alignedend = RoundUpTo(tmpend, scale_factor) - 1;
  if (alignedend < tmpend) alignedend += scale_factor;
  uptr end_byte = tmpend - (alignedend + 1 - scale_factor);

  uptr shadowSize = (alignedend - alignedbeg) / scale_factor + 1;  // units touched
  uptr shadowend = shadowbeg + shadowSize - 1;

  if (shadowSize > 1) {
    su_t firstByteValue = *(su_t *)(shadowbeg);
    if (IsPoisonedUnit(firstByteValue, beg_byte, scale_factor - 1)) return tmpbeg;

    su_t lastByteValue = *(su_t *)(shadowend);
    if (IsPoisonedUnit(lastByteValue, 0, end_byte)) return tmpend;
    // Every interior unit, i.e. [shadowbeg + 1, shadowend - 1] inclusive, must
    // be 0. The bound is `i < shadowend` so that the unit just before the last
    // one is examined as well.
    // TODO: read 8 bytes per iteration
    for (uptr i = shadowbeg + 1; i < shadowend; ++i) {
      if (*((su_t *)i) != 0) return SHADOW_TO_MEM(i);
    }
    return 0;
  } else {
    su_t firstByteValue = *(su_t *)(shadowbeg);
    if (IsPoisonedUnit(firstByteValue, beg_byte, end_byte)) {
      return tmpbeg;
    }
    else {
      return 0;
    }
  }
}

} // namespace __sanitizer